{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://f000.backblazeb2.com/file/malaya-model/bert-bahasa/alxlnet-base-2020-04-10.tar.gz\n",
    "# !tar -zxf alxlnet-base-2020-04-10.tar.gz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "config.json\t\t\t       model.ckpt-300000.index\r\n",
      "model.ckpt-300000.data-00000-of-00001  model.ckpt-300000.meta\r\n"
     ]
    }
   ],
   "source": [
    "!ls alxlnet-base-2020-04-10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://raw.githubusercontent.com/huseinzol05/Malaya/master/pretrained-model/alxlnet/config/alxlnet-base_config.json\n",
    "# !wget https://raw.githubusercontent.com/huseinzol05/Malaya/master/pretrained-model/preprocess/sp10m.cased.v9.model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/malaya/Malaya/session/relevancy/model_utils.py:334: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import xlnet\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tqdm import tqdm\n",
    "import model_utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import sentencepiece as spm\n",
    "from prepro_utils import preprocess_text, encode_ids\n",
    "\n",
    "sp_model = spm.SentencePieceProcessor()\n",
    "sp_model.Load('sp10m.cased.v9.model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "with open('dataset.json') as fopen:\n",
    "    data = json.load(fopen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_X = data['train_X']\n",
    "train_Y = data['train_Y']\n",
    "test_X = data['test_X']\n",
    "test_Y = data['test_Y']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from prepro_utils import preprocess_text, encode_ids\n",
    "\n",
    "def tokenize_fn(text):\n",
    "    \"\"\"Normalise `text` (lower-casing disabled) and encode it with the loaded SentencePiece model.\"\"\"\n",
    "    normalized = preprocess_text(text, lower = False)\n",
    "    return encode_ids(sp_model, normalized)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEG_ID_A   = 0\n",
    "SEG_ID_B   = 1\n",
    "SEG_ID_CLS = 2\n",
    "SEG_ID_SEP = 3\n",
    "SEG_ID_PAD = 4\n",
    "\n",
    "special_symbols = {\n",
    "    \"<unk>\"  : 0,\n",
    "    \"<s>\"    : 1,\n",
    "    \"</s>\"   : 2,\n",
    "    \"<cls>\"  : 3,\n",
    "    \"<sep>\"  : 4,\n",
    "    \"<pad>\"  : 5,\n",
    "    \"<mask>\" : 6,\n",
    "    \"<eod>\"  : 7,\n",
    "    \"<eop>\"  : 8,\n",
    "}\n",
    "\n",
    "VOCAB_SIZE = 32000\n",
    "UNK_ID = special_symbols[\"<unk>\"]\n",
    "CLS_ID = special_symbols[\"<cls>\"]\n",
    "SEP_ID = special_symbols[\"<sep>\"]\n",
    "MASK_ID = special_symbols[\"<mask>\"]\n",
    "EOD_ID = special_symbols[\"<eod>\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "from unidecode import unidecode\n",
    "\n",
    "def cleaning(string):\n",
    "    \"\"\"Pass the text through unidecode, strip URLs and drop words that start with '@'.\"\"\"\n",
    "    ascii_text = unidecode(string)\n",
    "    # Remove anything shaped like scheme://host/path.\n",
    "    without_urls = re.sub(r'\\w+:\\/{2}[\\d\\w-]+(\\.[\\d\\w-]+)*(?:(?:\\/[^\\s/]*))*', '', ascii_text)\n",
    "    collapsed = re.sub(r'[ ]+', ' ', without_urls)\n",
    "    kept = []\n",
    "    for word in collapsed.strip().split():\n",
    "        if word[0] != '@':\n",
    "            kept.append(word)\n",
    "    return ' '.join(kept)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 51761/51761 [00:16<00:00, 3079.12it/s]\n",
      "100%|██████████| 5029/5029 [00:01<00:00, 2708.44it/s]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "# Clean both splits in place; each split gets its own progress bar.\n",
    "for texts in (train_X, test_X):\n",
    "    for i in tqdm(range(len(texts))):\n",
    "        texts[i] = cleaning(texts[i])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "MAX_SEQ_LENGTH = 510\n",
    "\n",
    "def get_tokenize(strings):\n",
    "    \"\"\"Encode each text as token ids followed by SEP_ID and CLS_ID.\n",
    "\n",
    "    Returns three parallel lists (input ids, input masks, segment ids) with one\n",
    "    unpadded entry per input string. Every mask entry is 0; no padding is added here.\n",
    "    \"\"\"\n",
    "    budget = MAX_SEQ_LENGTH - 2  # leave room for the two trailing special tokens\n",
    "    input_ids, input_masks, segment_ids = [], [], []\n",
    "    for text in tqdm(strings):\n",
    "        pieces = tokenize_fn(text)[:budget]\n",
    "        ids = pieces + [SEP_ID, CLS_ID]\n",
    "\n",
    "        input_ids.append(ids)\n",
    "        input_masks.append([0] * len(ids))\n",
    "        # The separator shares segment A; the classification token has its own segment.\n",
    "        segment_ids.append([SEG_ID_A] * (len(pieces) + 1) + [SEG_ID_CLS])\n",
    "\n",
    "    return input_ids, input_masks, segment_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 51761/51761 [03:33<00:00, 242.29it/s]\n",
      "100%|██████████| 5029/5029 [00:24<00:00, 206.57it/s]\n"
     ]
    }
   ],
   "source": [
    "train_input_ids, train_input_masks, train_segment_ids = get_tokenize(train_X)\n",
    "test_input_ids, test_input_masks, test_segment_ids = get_tokenize(test_X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/malaya/Malaya/session/relevancy/xlnet.py:70: The name tf.gfile.Open is deprecated. Please use tf.io.gfile.GFile instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Run-time switches for the XLNet graph used while fine-tuning (dropout enabled).\n",
    "kwargs = {\n",
    "    'is_training': True,\n",
    "    'use_tpu': False,\n",
    "    'use_bfloat16': False,\n",
    "    'dropout': 0.1,\n",
    "    'dropatt': 0.1,\n",
    "    'init': 'normal',\n",
    "    'init_range': 0.1,\n",
    "    'init_std': 0.05,\n",
    "    'clamp_len': -1,\n",
    "}\n",
    "\n",
    "xlnet_parameters = xlnet.RunConfig(**kwargs)\n",
    "# Architecture settings are read from the downloaded JSON file.\n",
    "xlnet_config = xlnet.XLNetConfig(json_path='alxlnet-base_config.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "64701 6470\n"
     ]
    }
   ],
   "source": [
    "epoch = 10\n",
    "batch_size = 8\n",
    "warmup_proportion = 0.1\n",
    "learning_rate = 2e-5\n",
    "\n",
    "# The training loop also runs the final, shorter batch of every pass, so count\n",
    "# whole optimizer steps per epoch (ceiling division) instead of fractional batches.\n",
    "steps_per_epoch = -(-len(train_X) // batch_size)\n",
    "num_train_steps = steps_per_epoch * epoch\n",
    "num_warmup_steps = int(num_train_steps * warmup_proportion)\n",
    "print(num_train_steps, num_warmup_steps)\n",
    "\n",
    "# Settings consumed through the Parameter wrapper; keys its constructor does not\n",
    "# name (dropout, init, ...) are accepted there and ignored.\n",
    "training_parameters = dict(\n",
    "      decay_method = 'poly',\n",
    "      train_steps = num_train_steps,\n",
    "      learning_rate = learning_rate,\n",
    "      warmup_steps = num_warmup_steps,\n",
    "      min_lr_ratio = 0.0,\n",
    "      weight_decay = 0.00,\n",
    "      adam_epsilon = 1e-8,\n",
    "      num_core_per_host = 1,\n",
    "      lr_layer_decay_rate = 1,\n",
    "      use_tpu=False,\n",
    "      use_bfloat16=False,\n",
    "      dropout=0.0,\n",
    "      dropatt=0.0,\n",
    "      init='normal',\n",
    "      init_range=0.1,\n",
    "      init_std=0.05,\n",
    "      clip = 1.0,\n",
    "      clamp_len=-1,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Parameter:\n",
    "    \"\"\"Attribute-style holder for the optimiser settings handed to model_utils.get_train_op.\n",
    "\n",
    "    Only the named arguments are stored; any extra keyword arguments are\n",
    "    accepted and ignored.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, decay_method, warmup_steps, weight_decay, adam_epsilon,\n",
    "                num_core_per_host, lr_layer_decay_rate, use_tpu, learning_rate, train_steps,\n",
    "                min_lr_ratio, clip, **kwargs):\n",
    "        self.decay_method = decay_method\n",
    "        self.warmup_steps = warmup_steps\n",
    "        self.weight_decay = weight_decay\n",
    "        self.adam_epsilon = adam_epsilon\n",
    "        self.num_core_per_host = num_core_per_host\n",
    "        self.lr_layer_decay_rate = lr_layer_decay_rate\n",
    "        self.use_tpu = use_tpu\n",
    "        self.learning_rate = learning_rate\n",
    "        self.train_steps = train_steps\n",
    "        self.min_lr_ratio = min_lr_ratio\n",
    "        self.clip = clip\n",
    "\n",
    "# After the first run `training_parameters` is already a Parameter, which cannot be\n",
    "# unpacked with **; convert only while it is still the dict so the cell can be re-run.\n",
    "if isinstance(training_parameters, dict):\n",
    "    training_parameters = Parameter(**training_parameters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model:\n",
    "    \"\"\"XLNet encoder with a dense classification head, built in the default TF graph.\n",
    "\n",
    "    Reads the module-level `xlnet_config`, `xlnet_parameters` and\n",
    "    `training_parameters` when constructed.\n",
    "\n",
    "    Args:\n",
    "        dimension_output: width of the dense output layer (number of classes).\n",
    "        learning_rate: not used by the constructor; `self.learning_rate` is set\n",
    "            from the value returned by model_utils.get_train_op.\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self,\n",
    "        dimension_output,\n",
    "        learning_rate = 2e-5,\n",
    "    ):\n",
    "        # Feeds: token ids, segment ids and masks are 2-D; labels are 1-D.\n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.segment_ids = tf.placeholder(tf.int32, [None, None])\n",
    "        self.input_masks = tf.placeholder(tf.float32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None])\n",
    "        \n",
    "        # The first two axes of every input are swapped before the encoder\n",
    "        # (presumably it expects sequence-major tensors; confirm against xlnet.XLNetModel).\n",
    "        xlnet_model = xlnet.XLNetModel(\n",
    "            xlnet_config=xlnet_config,\n",
    "            run_config=xlnet_parameters,\n",
    "            input_ids=tf.transpose(self.X, [1, 0]),\n",
    "            seg_ids=tf.transpose(self.segment_ids, [1, 0]),\n",
    "            input_mask=tf.transpose(self.input_masks, [1, 0]))\n",
    "        \n",
    "        # Swap the first two axes of the encoder output back again.\n",
    "        output_layer = xlnet_model.get_sequence_output()\n",
    "        output_layer = tf.transpose(output_layer, [1, 0, 2])\n",
    "        \n",
    "        # Per-position logits, exported under fixed names; the classification\n",
    "        # logits are the ones at index 0 of the second axis.\n",
    "        # NOTE(review): get_tokenize appends CLS_ID at the end of each sequence while\n",
    "        # index 0 is read here; confirm this is intended.\n",
    "        self.logits_seq = tf.layers.dense(output_layer, dimension_output)\n",
    "        self.logits_seq = tf.identity(self.logits_seq, name = 'logits_seq')\n",
    "        self.logits = self.logits_seq[:, 0]\n",
    "        self.logits = tf.identity(self.logits, name = 'logits')\n",
    "        \n",
    "        # Mean sparse softmax cross-entropy over the batch.\n",
    "        self.cost = tf.reduce_mean(\n",
    "            tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
    "                logits = self.logits, labels = self.Y\n",
    "            )\n",
    "        )\n",
    "        \n",
    "        # Train op and learning-rate tensor come from the project helper; the third return value is ignored.\n",
    "        self.optimizer, self.learning_rate, _ = model_utils.get_train_op(training_parameters, self.cost)\n",
    "        \n",
    "        # Fraction of the batch whose arg-max logit equals the label.\n",
    "        correct_pred = tf.equal(\n",
    "            tf.argmax(self.logits, 1, output_type = tf.int32), self.Y\n",
    "        )\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/malaya/Malaya/session/relevancy/xlnet.py:253: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/malaya/Malaya/session/relevancy/xlnet.py:253: The name tf.AUTO_REUSE is deprecated. Please use tf.compat.v1.AUTO_REUSE instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/malaya/Malaya/session/relevancy/custom_modeling.py:696: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.\n",
      "\n",
      "INFO:tensorflow:memory input None\n",
      "INFO:tensorflow:Use float type <dtype: 'float32'>\n",
      "WARNING:tensorflow:From /home/husein/malaya/Malaya/session/relevancy/custom_modeling.py:703: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/malaya/Malaya/session/relevancy/custom_modeling.py:808: dropout (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.dropout instead.\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/layers/core.py:271: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.__call__` method instead.\n",
      "WARNING:tensorflow:\n",
      "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "  * https://github.com/tensorflow/io (for I/O related ops)\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/malaya/Malaya/session/relevancy/custom_modeling.py:109: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.Dense instead.\n",
      "WARNING:tensorflow:From /home/husein/malaya/Malaya/session/relevancy/model_utils.py:105: The name tf.train.get_or_create_global_step is deprecated. Please use tf.compat.v1.train.get_or_create_global_step instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/malaya/Malaya/session/relevancy/model_utils.py:119: The name tf.train.polynomial_decay is deprecated. Please use tf.compat.v1.train.polynomial_decay instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/malaya/Malaya/session/relevancy/model_utils.py:136: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n",
      "WARNING:tensorflow:From /home/husein/malaya/Malaya/session/relevancy/model_utils.py:150: The name tf.train.AdamOptimizer is deprecated. Please use tf.compat.v1.train.AdamOptimizer instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Width of the classifier's output layer.\n",
    "dimension_output = 2\n",
    "\n",
    "# Start from an empty default graph, then build the classifier under a new interactive session.\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Model(dimension_output, learning_rate)\n",
    "\n",
    "init_op = tf.global_variables_initializer()\n",
    "sess.run(init_op)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "import re\n",
    "\n",
    "def get_assignment_map_from_checkpoint(tvars, init_checkpoint):\n",
    "    \"\"\"Match graph variables to checkpoint entries that have the same name.\n",
    "\n",
    "    Args:\n",
    "        tvars: iterable of variables exposing a `name` attribute.\n",
    "        init_checkpoint: checkpoint path passed to tf.train.list_variables.\n",
    "\n",
    "    Returns:\n",
    "        (assignment_map, initialized_variable_names): an ordered mapping from\n",
    "        checkpoint variable name to the graph variable of the same name, and a\n",
    "        dict that marks every matched name, with and without the ':0' suffix, with 1.\n",
    "    \"\"\"\n",
    "    # Index graph variables by name with any trailing ':<digits>' removed.\n",
    "    name_to_variable = collections.OrderedDict()\n",
    "    for var in tvars:\n",
    "        name = var.name\n",
    "        m = re.match('^(.*):\\\\d+$', name)\n",
    "        if m is not None:\n",
    "            name = m.group(1)\n",
    "        name_to_variable[name] = var\n",
    "\n",
    "    assignment_map = collections.OrderedDict()\n",
    "    initialized_variable_names = {}\n",
    "    # Each checkpoint entry is a (name, shape) pair; only the name is needed.\n",
    "    for name, _shape in tf.train.list_variables(init_checkpoint):\n",
    "        if name not in name_to_variable:\n",
    "            continue\n",
    "        assignment_map[name] = name_to_variable[name]\n",
    "        initialized_variable_names[name] = 1\n",
    "        initialized_variable_names[name + ':0'] = 1\n",
    "\n",
    "    return (assignment_map, initialized_variable_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "tvars = tf.trainable_variables()\n",
    "checkpoint = 'alxlnet-base-2020-04-10/model.ckpt-300000'\n",
    "assignment_map, initialized_variable_names = get_assignment_map_from_checkpoint(tvars, \n",
    "                                                                                checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from alxlnet-base-2020-04-10/model.ckpt-300000\n"
     ]
    }
   ],
   "source": [
    "saver = tf.train.Saver(var_list = assignment_map)\n",
    "saver.restore(sess, checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.25, 0.95578206]"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pad_sequences = tf.keras.preprocessing.sequence.pad_sequences\n",
    "\n",
    "# Smoke test: evaluate accuracy and loss on the first four training examples.\n",
    "i = 0\n",
    "index = 4\n",
    "batch_x = pad_sequences(train_input_ids[i : index], padding='post')\n",
    "batch_y = train_Y[i : index]\n",
    "# Padded positions get mask value 1 and the dedicated padding segment id.\n",
    "batch_masks = pad_sequences(train_input_masks[i : index], padding='post', value = 1)\n",
    "batch_segments = pad_sequences(train_segment_ids[i : index], padding='post', value = SEG_ID_PAD)\n",
    "\n",
    "sess.run(\n",
    "    [model.accuracy, model.cost],\n",
    "    feed_dict = {\n",
    "        model.X: batch_x,\n",
    "        model.Y: batch_y,\n",
    "        model.segment_ids: batch_segments,\n",
    "        model.input_masks: batch_masks,\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 6471/6471 [1:13:19<00:00,  1.47it/s, accuracy=0, cost=5.09]     \n",
      "test minibatch loop: 100%|██████████| 629/629 [02:36<00:00,  4.03it/s, accuracy=1, cost=0.0264]   \n",
      "train minibatch loop:   0%|          | 0/6471 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, pass acc: 0.000000, current acc: 0.837043\n",
      "time taken: 4556.105219125748\n",
      "epoch: 0, training loss: 0.574205, training acc: 0.739163, valid loss: 0.483871, valid acc: 0.837043\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 6471/6471 [1:13:15<00:00,  1.47it/s, accuracy=0, cost=6.9]       \n",
      "test minibatch loop: 100%|██████████| 629/629 [02:36<00:00,  4.03it/s, accuracy=0.8, cost=0.785]  \n",
      "train minibatch loop:   0%|          | 0/6471 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1, pass acc: 0.837043, current acc: 0.873688\n",
      "time taken: 4552.077011346817\n",
      "epoch: 1, training loss: 0.379126, training acc: 0.861324, valid loss: 0.505004, valid acc: 0.873688\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 6471/6471 [1:13:16<00:00,  1.47it/s, accuracy=0, cost=7.17]      \n",
      "test minibatch loop: 100%|██████████| 629/629 [02:36<00:00,  4.03it/s, accuracy=0.8, cost=1.51]    \n",
      "train minibatch loop:   0%|          | 0/6471 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 2, pass acc: 0.873688, current acc: 0.886010\n",
      "time taken: 4552.549909830093\n",
      "epoch: 2, training loss: 0.324412, training acc: 0.903763, valid loss: 0.485360, valid acc: 0.886010\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 6471/6471 [1:13:16<00:00,  1.47it/s, accuracy=0, cost=7.09]      \n",
      "test minibatch loop: 100%|██████████| 629/629 [02:36<00:00,  4.03it/s, accuracy=0.8, cost=0.241]   \n",
      "train minibatch loop:   0%|          | 0/6471 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 3, pass acc: 0.886010, current acc: 0.891176\n",
      "time taken: 4552.269495487213\n",
      "epoch: 3, training loss: 0.269619, training acc: 0.930034, valid loss: 0.578096, valid acc: 0.891176\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 6471/6471 [1:13:15<00:00,  1.47it/s, accuracy=0, cost=6.62]      \n",
      "test minibatch loop: 100%|██████████| 629/629 [02:36<00:00,  4.03it/s, accuracy=1, cost=0.000712] \n",
      "train minibatch loop:   0%|          | 0/6471 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 4, pass acc: 0.891176, current acc: 0.892289\n",
      "time taken: 4551.564073085785\n",
      "epoch: 4, training loss: 0.215765, training acc: 0.949216, valid loss: 0.620838, valid acc: 0.892289\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 6471/6471 [1:13:17<00:00,  1.47it/s, accuracy=0, cost=7.24]      \n",
      "test minibatch loop: 100%|██████████| 629/629 [02:36<00:00,  4.02it/s, accuracy=1, cost=0.000583]  \n",
      "train minibatch loop:   0%|          | 0/6471 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 5, pass acc: 0.892289, current acc: 0.896661\n",
      "time taken: 4554.211331129074\n",
      "epoch: 5, training loss: 0.169841, training acc: 0.962062, valid loss: 0.656199, valid acc: 0.896661\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 6471/6471 [1:13:29<00:00,  1.47it/s, accuracy=0, cost=5.32]      \n",
      "test minibatch loop: 100%|██████████| 629/629 [02:36<00:00,  4.01it/s, accuracy=1, cost=3.09e-5]   \n",
      "train minibatch loop:   0%|          | 0/6471 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 6, pass acc: 0.896661, current acc: 0.899245\n",
      "time taken: 4566.070193052292\n",
      "epoch: 6, training loss: 0.133136, training acc: 0.971353, valid loss: 0.665070, valid acc: 0.899245\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 6471/6471 [1:13:24<00:00,  1.47it/s, accuracy=0, cost=5.36]      \n",
      "test minibatch loop: 100%|██████████| 629/629 [02:36<00:00,  4.01it/s, accuracy=1, cost=5.5e-5]    \n",
      "train minibatch loop:   0%|          | 0/6471 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 7, pass acc: 0.899245, current acc: 0.903021\n",
      "time taken: 4561.992794036865\n",
      "epoch: 7, training loss: 0.099682, training acc: 0.978616, valid loss: 0.683677, valid acc: 0.903021\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 6471/6471 [1:13:33<00:00,  1.47it/s, accuracy=0, cost=5.01]      \n",
      "test minibatch loop: 100%|██████████| 629/629 [02:37<00:00,  4.00it/s, accuracy=1, cost=0.0411]    \n",
      "train minibatch loop:   0%|          | 0/6471 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 8, pass acc: 0.903021, current acc: 0.907393\n",
      "time taken: 4571.273866415024\n",
      "epoch: 8, training loss: 0.077829, training acc: 0.983561, valid loss: 0.671531, valid acc: 0.907393\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 6471/6471 [1:13:36<00:00,  1.47it/s, accuracy=0, cost=3.02]      \n",
      "test minibatch loop: 100%|██████████| 629/629 [02:37<00:00,  4.00it/s, accuracy=1, cost=0.00295]  \n",
      "train minibatch loop:   0%|          | 0/6471 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 9, pass acc: 0.907393, current acc: 0.910970\n",
      "time taken: 4574.264142513275\n",
      "epoch: 9, training loss: 0.059021, training acc: 0.987734, valid loss: 0.687154, valid acc: 0.910970\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 6471/6471 [1:13:36<00:00,  1.47it/s, accuracy=0, cost=4.21]     \n",
      "test minibatch loop: 100%|██████████| 629/629 [02:37<00:00,  3.99it/s, accuracy=1, cost=3.03e-6]   "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 4573.7364337444305\n",
      "epoch: 10, training loss: 0.053832, training acc: 0.988197, valid loss: 0.719783, valid acc: 0.908188\n",
      "\n",
      "break epoch:11\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "import time\n",
    "\n",
    "\n",
    "EARLY_STOPPING, CURRENT_CHECKPOINT, CURRENT_ACC, EPOCH = 1, 0, 0, 0\n",
    "\n",
    "\n",
    "def _run_epoch(input_ids, input_masks, segment_ids, labels, fetches, desc):\n",
    "    \"\"\"Run one pass over a dataset in mini-batches and return (mean loss, mean accuracy).\n",
    "\n",
    "    `fetches` must begin with [model.accuracy, model.cost]; any further entries\n",
    "    (such as the train op) are executed and their results discarded. Means are\n",
    "    weighted by batch size so a short final batch does not count as a full one.\n",
    "    \"\"\"\n",
    "    total_loss, total_acc, seen = 0.0, 0.0, 0\n",
    "    pbar = tqdm(range(0, len(input_ids), batch_size), desc = desc)\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(input_ids))\n",
    "        batch_x = pad_sequences(input_ids[i : index], padding='post')\n",
    "        batch_y = labels[i : index]\n",
    "        # Padded positions get mask value 1 and the dedicated padding segment id.\n",
    "        batch_masks = pad_sequences(input_masks[i : index], padding='post', value = 1)\n",
    "        batch_segments = pad_sequences(segment_ids[i : index], padding='post', value = SEG_ID_PAD)\n",
    "        acc, cost = sess.run(\n",
    "            fetches,\n",
    "            feed_dict = {\n",
    "                model.X: batch_x,\n",
    "                model.Y: batch_y,\n",
    "                model.segment_ids: batch_segments,\n",
    "                model.input_masks: batch_masks,\n",
    "            },\n",
    "        )[:2]\n",
    "        examples = index - i\n",
    "        total_loss += cost * examples\n",
    "        total_acc += acc * examples\n",
    "        seen += examples\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc)\n",
    "    seen = max(seen, 1)  # avoid dividing by zero on an empty dataset\n",
    "    return total_loss / seen, total_acc / seen\n",
    "\n",
    "\n",
    "while True:\n",
    "    lasttime = time.time()\n",
    "    # Stop when validation accuracy failed to improve EARLY_STOPPING times in a row, or\n",
    "    # once `epoch` passes are done: num_train_steps (the schedule length given to the\n",
    "    # optimizer) was sized for exactly that many passes.\n",
    "    if CURRENT_CHECKPOINT == EARLY_STOPPING or EPOCH >= epoch:\n",
    "        print('break epoch:%d\\n' % (EPOCH))\n",
    "        break\n",
    "\n",
    "    train_loss, train_acc = _run_epoch(\n",
    "        train_input_ids, train_input_masks, train_segment_ids, train_Y,\n",
    "        [model.accuracy, model.cost, model.optimizer], 'train minibatch loop'\n",
    "    )\n",
    "    # NOTE(review): the graph was built with xlnet_parameters is_training=True, so this\n",
    "    # validation pass presumably runs with dropout active; confirm against xlnet.RunConfig.\n",
    "    test_loss, test_acc = _run_epoch(\n",
    "        test_input_ids, test_input_masks, test_segment_ids, test_Y,\n",
    "        [model.accuracy, model.cost], 'test minibatch loop'\n",
    "    )\n",
    "\n",
    "    if test_acc > CURRENT_ACC:\n",
    "        print(\n",
    "            'epoch: %d, pass acc: %f, current acc: %f'\n",
    "            % (EPOCH, CURRENT_ACC, test_acc)\n",
    "        )\n",
    "        CURRENT_ACC = test_acc\n",
    "        CURRENT_CHECKPOINT = 0\n",
    "    else:\n",
    "        CURRENT_CHECKPOINT += 1\n",
    "\n",
    "    print('time taken:', time.time() - lasttime)\n",
    "    print(\n",
    "        'epoch: %d, training loss: %f, training acc: %f, valid loss: %f, valid acc: %f\\n'\n",
    "        % (EPOCH, train_loss, train_acc, test_loss, test_acc)\n",
    "    )\n",
    "    EPOCH += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'alxlnet-base-relevancy/model.ckpt'"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "saver = tf.train.Saver(tf.trainable_variables())\n",
    "saver.save(sess, 'alxlnet-base-relevancy/model.ckpt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run-time switches for the XLNet graph used at inference time (dropout disabled).\n",
    "kwargs = {\n",
    "    'is_training': False,\n",
    "    'use_tpu': False,\n",
    "    'use_bfloat16': False,\n",
    "    'dropout': 0.0,\n",
    "    'dropatt': 0.0,\n",
    "    'init': 'normal',\n",
    "    'init_range': 0.1,\n",
    "    'init_std': 0.05,\n",
    "    'clamp_len': -1,\n",
    "}\n",
    "\n",
    "xlnet_parameters = xlnet.RunConfig(**kwargs)\n",
    "# Architecture settings are read from the downloaded JSON file.\n",
    "xlnet_config = xlnet.XLNetConfig(json_path='alxlnet-base_config.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:memory input None\n",
      "INFO:tensorflow:Use float type <dtype: 'float32'>\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/client/session.py:1750: UserWarning: An interactive session is already active. This can cause out-of-memory errors in some cases. You must explicitly call `InteractiveSession.close()` to release resources held by the other session(s).\n",
      "  warnings.warn('An interactive session is already active. This can '\n"
     ]
    }
   ],
   "source": [
    "dimension_output = 2\n",
    "learning_rate = 2e-5\n",
    "\n",
    "# Close the session left over from training before opening a new one; TensorFlow\n",
    "# warns that a second active InteractiveSession can cause out-of-memory errors.\n",
    "previous_sess = globals().get('sess')\n",
    "if previous_sess is not None:\n",
    "    previous_sess.close()\n",
    "\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Model(\n",
    "    dimension_output,\n",
    "    learning_rate\n",
    ")\n",
    "\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from alxlnet-base-relevancy/model.ckpt\n"
     ]
    }
   ],
   "source": [
    "saver = tf.train.Saver(tf.trainable_variables())\n",
    "saver.restore(sess, 'alxlnet-base-relevancy/model.ckpt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "test minibatch loop: 100%|██████████| 629/629 [02:14<00:00,  4.69it/s]\n"
     ]
    }
   ],
   "source": [
    "real_Y, predict_Y = [], []\n",
    "\n",
    "pbar = tqdm(range(0, len(test_input_ids), batch_size), desc = 'test minibatch loop')\n",
    "for i in pbar:\n",
    "    index = min(i + batch_size, len(test_input_ids))\n",
    "    batch_x = pad_sequences(test_input_ids[i : index], padding='post')\n",
    "    batch_y = test_Y[i : index]\n",
    "    batch_masks = pad_sequences(test_input_masks[i : index], padding='post', value = 1)\n",
    "    batch_segments = pad_sequences(test_segment_ids[i : index], padding='post', value = 4)\n",
    "    predict_Y += np.argmax(sess.run(model.logits,\n",
    "            feed_dict = {\n",
    "                model.Y: batch_y,\n",
    "                model.X: batch_x,\n",
    "                model.segment_ids: batch_segments,\n",
    "                model.input_masks: batch_masks\n",
    "            },\n",
    "    ), 1, ).tolist()\n",
    "    real_Y += batch_y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "not relevant    0.90742   0.86683   0.88666      1990\n",
      "    relevant    0.91528   0.94209   0.92849      3039\n",
      "\n",
      "    accuracy                        0.91231      5029\n",
      "   macro avg    0.91135   0.90446   0.90758      5029\n",
      "weighted avg    0.91217   0.91231   0.91194      5029\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from sklearn import metrics\n",
    "\n",
    "print(\n",
    "    metrics.classification_report(\n",
    "        real_Y, predict_Y, target_names = ['not relevant', 'relevant'],\n",
    "        digits = 5\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Placeholder',\n",
       " 'Placeholder_1',\n",
       " 'Placeholder_2',\n",
       " 'Placeholder_3',\n",
       " 'model/transformer/r_w_bias',\n",
       " 'model/transformer/r_r_bias',\n",
       " 'model/transformer/word_embedding/lookup_table',\n",
       " 'model/transformer/word_embedding/lookup_table_2',\n",
       " 'model/transformer/r_s_bias',\n",
       " 'model/transformer/seg_embed',\n",
       " 'model/transformer/layer_shared/rel_attn/q/kernel',\n",
       " 'model/transformer/layer_shared/rel_attn/k/kernel',\n",
       " 'model/transformer/layer_shared/rel_attn/v/kernel',\n",
       " 'model/transformer/layer_shared/rel_attn/r/kernel',\n",
       " 'model/transformer/layer_shared/rel_attn/o/kernel',\n",
       " 'model/transformer/layer_shared/rel_attn/LayerNorm/gamma',\n",
       " 'model/transformer/layer_shared/ff/layer_1/kernel',\n",
       " 'model/transformer/layer_shared/ff/layer_1/bias',\n",
       " 'model/transformer/layer_shared/ff/layer_2/kernel',\n",
       " 'model/transformer/layer_shared/ff/layer_2/bias',\n",
       " 'model/transformer/layer_shared/ff/LayerNorm/gamma',\n",
       " 'dense/kernel',\n",
       " 'dense/bias',\n",
       " 'logits_seq',\n",
       " 'logits']"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "strings = ','.join(\n",
    "    [\n",
    "        n.name\n",
    "        for n in tf.get_default_graph().as_graph_def().node\n",
    "        if ('Variable' in n.op\n",
    "        or 'Placeholder' in n.name\n",
    "        or 'logits' in n.name\n",
    "        or 'alphas' in n.name\n",
    "        or 'self/Softmax' in n.name)\n",
    "        and 'Adam' not in n.name\n",
    "        and 'beta' not in n.name\n",
    "        and 'global_step' not in n.name\n",
    "    ]\n",
    ")\n",
    "strings.split(',')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "def freeze_graph(model_dir, output_node_names):\n",
    "\n",
    "    if not tf.gfile.Exists(model_dir):\n",
    "        raise AssertionError(\n",
    "            \"Export directory doesn't exists. Please specify an export \"\n",
    "            'directory: %s' % model_dir\n",
    "        )\n",
    "\n",
    "    checkpoint = tf.train.get_checkpoint_state(model_dir)\n",
    "    input_checkpoint = checkpoint.model_checkpoint_path\n",
    "\n",
    "    absolute_model_dir = '/'.join(input_checkpoint.split('/')[:-1])\n",
    "    output_graph = absolute_model_dir + '/frozen_model.pb'\n",
    "    clear_devices = True\n",
    "    with tf.Session(graph = tf.Graph()) as sess:\n",
    "        saver = tf.train.import_meta_graph(\n",
    "            input_checkpoint + '.meta', clear_devices = clear_devices\n",
    "        )\n",
    "        saver.restore(sess, input_checkpoint)\n",
    "        output_graph_def = tf.graph_util.convert_variables_to_constants(\n",
    "            sess,\n",
    "            tf.get_default_graph().as_graph_def(),\n",
    "            output_node_names.split(','),\n",
    "        )\n",
    "        with tf.gfile.GFile(output_graph, 'wb') as f:\n",
    "            f.write(output_graph_def.SerializeToString())\n",
    "        print('%d ops in the final graph.' % len(output_graph_def.node))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from alxlnet-base-relevancy/model.ckpt\n",
      "WARNING:tensorflow:From <ipython-input-32-9a7215a4e58a>:23: convert_variables_to_constants (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.graph_util.convert_variables_to_constants`\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/framework/graph_util_impl.py:277: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.graph_util.extract_sub_graph`\n",
      "INFO:tensorflow:Froze 21 variables.\n",
      "INFO:tensorflow:Converted 21 variables to const ops.\n",
      "7390 ops in the final graph.\n"
     ]
    }
   ],
   "source": [
    "freeze_graph('alxlnet-base-relevancy', strings)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
